#!/usr/bin/env python
"""Developed against Rnor6 RefSeq"""
import argparse
from BCBio import GFF
from collections import Counter, OrderedDict


def parse_args():
    """Parse the command line and return the resulting namespace.

    Two positional arguments are required, in this order: ``input_gff3``
    and ``output_gff3``.
    """
    positionals = (
        ('input_gff3', 'RefSeq GFF3 file'),
        ('output_gff3', 'Output GFF3'),
    )
    cli = argparse.ArgumentParser()
    for dest, description in positionals:
        cli.add_argument(dest, help=description)
    return cli.parse_args()


def get_unique_id(i, c, update_id=True):
    """Return ``i``, suffixed with ``.N`` once it has been registered more than once.

    ``c`` maps each identifier to the number of times it has been
    re-registered.  The first time ``i`` is seen it is stored with a count
    of 0 and returned unchanged, regardless of ``update_id``.  On later
    calls the count is bumped only when ``update_id`` is exactly ``True``;
    with any other value the current count is looked up without changing it.

    Returns ``i`` itself while the count is 0, otherwise ``"<i>.<count>"``.
    """
    if i not in c:
        # First sighting: register it without a suffix.
        c[i] = 0
    elif update_id is True:
        c[i] += 1
    occurrence = c[i]
    return f"{i}.{occurrence}" if occurrence > 0 else i


# Feature types that receive gene/transcript tags; every other type is
# left untouched by add_tags_to_feature.
_TAGGED_FEATURE_TYPES = frozenset({
    'ncRNA', 'tRNA', 'rRNA', 'pseudogene', 'lnc_RNA',
    'mRNA', 'snRNA', 'snoRNA', 'scRNA', 'antisense_RNA', 'guide_RNA',
    'C_gene_segment', 'J_gene_segment', 'V_gene_segment',
    'primary_transcript', 'miRNA', 'transcript', 'CDS', 'exon', 'gene',
    'SRP_RNA', 'telomerase_RNA',
})


def add_tags_to_feature(feature, noid_count, tx_counts, gene_counts, parent_feature=None):
    """Add gene/transcript id, name and biotype qualifiers to ``feature`` in place.

    Features whose ``type`` is not in ``_TAGGED_FEATURE_TYPES`` are returned
    untouched.  For a ``gene`` only ``gene_id`` and ``gene_name`` are set.
    Every other supported type additionally receives ``transcript_id``,
    ``transcript_name``, ``gene_biotype`` and ``transcript_biotype``, and
    ``feature.type`` itself is rewritten for gene segments (to ``ncRNA``)
    and for ``lnc_RNA`` (to ``lncRNA``).

    Args:
        feature: object exposing ``type`` and a ``qualifiers`` mapping whose
            values are lists of strings.
        noid_count: number used to build the placeholder ``GeneID`` when the
            feature has neither a ``dbxref`` nor a ``Dbxref`` qualifier.
        tx_counts: mapping handed to ``get_unique_id`` for transcript ids.
        gene_counts: mapping handed to ``get_unique_id`` for gene ids.
        parent_feature: feature whose ``gene_biotype`` qualifier is inherited.
            ``None`` means the feature is its own parent.

    Raises:
        KeyError: if the cross references contain no ``GeneID`` entry, or a
            non-gene feature has neither ``transcript_id`` nor ``ID``.
        AssertionError: if a gene is not its own parent or lacks
            ``gene_biotype``.
    """
    if feature.type not in _TAGGED_FEATURE_TYPES:
        return
    if parent_feature is None:
        # Without an explicit parent the feature stands for itself.
        parent_feature = feature
    qualifiers = feature.qualifiers

    if 'dbxref' in qualifiers:
        xref = qualifiers['dbxref']
    elif 'Dbxref' in qualifiers:
        xref = qualifiers['Dbxref']
    else:  # hacky fix
        # NOTE: noid_count is used as given and never advanced here; an int
        # argument cannot be updated for the caller from inside this function.
        xref = [f'GeneID:NoID-{noid_count}']
    # Each cross reference looks like "<database>:<accession>".
    ids = dict(x.split(':', 1) for x in xref)

    if feature.type == 'gene':
        assert feature == parent_feature
        qualifiers['gene_id'] = [get_unique_id(ids['GeneID'], gene_counts)]
        qualifiers['gene_name'] = qualifiers.get(
            'gene', qualifiers.get('product', qualifiers['gene_id']))
        assert 'gene_biotype' in qualifiers
        return

    # Non-gene features look up the gene's current id without registering
    # a new occurrence.
    qualifiers['gene_id'] = [get_unique_id(ids['GeneID'], gene_counts, update_id=False)]
    qualifiers['gene_name'] = qualifiers.get(
        'gene', qualifiers.get('product', qualifiers['gene_id']))

    # Only fall back to ID when transcript_id is absent, so a feature that
    # has a transcript_id but no ID does not fail on the unused lookup.
    if 'transcript_id' in qualifiers:
        tx_id = qualifiers['transcript_id'][0]
    else:
        tx_id = qualifiers['ID'][0]
    if '-' in tx_id:
        # Drop everything up to and including the first dash.
        tx_id = tx_id.split('-', 1)[1]
    # Anything other than CDS/exon is a transcript and registers a new
    # occurrence; CDS and exon reuse the transcript's current id.
    is_transcript = feature.type not in ('CDS', 'exon')
    tx_id = get_unique_id(tx_id, tx_counts, update_id=is_transcript)
    qualifiers['transcript_id'] = [tx_id]
    qualifiers['transcript_name'] = qualifiers.get('Name', qualifiers['transcript_id'])

    # figure out gene biotype
    if 'gene_biotype' in parent_feature.qualifiers:
        gene_biotype = parent_feature.qualifiers['gene_biotype']
    else:
        gene_biotype = [feature.type]

    # figure out transcript biotype
    if feature.type == 'primary_transcript':
        transcript_biotype = ['miRNA']
    elif 'segment' in feature.type:
        # Keep the segment kind as the biotype before retyping the feature.
        transcript_biotype = [feature.type]
        feature.type = 'ncRNA'
    elif feature.type == 'lnc_RNA':
        feature.type = 'lncRNA'
        transcript_biotype = [feature.type]
    else:
        transcript_biotype = gene_biotype
    qualifiers['gene_biotype'] = gene_biotype
    qualifiers['transcript_biotype'] = transcript_biotype


def construct_new_qualifiers(feature):
    """Return a normalised ``OrderedDict`` built from ``feature.qualifiers``.

    * Keys are lower-cased, except ``ID``, ``Parent`` and ``Name``.
    * A value list with more than one entry is collapsed to a single
      space-joined entry, with ``;`` and ``=`` inside each entry replaced
      by ``%3B`` and ``%3D``.
    * A qualifier whose entries contain no characters at all is replaced
      by the string ``'True'``.

    The mapping on ``feature`` is not modified.
    """
    case_sensitive = ('ID', 'Parent', 'Name')
    normalised = OrderedDict()
    for name, values in feature.qualifiers.items():
        if name not in case_sensitive:
            name = name.lower()
        if len(values) > 1:
            # Escape the separators, then merge everything into one entry.
            escaped = [v.replace(';', '%3B').replace('=', '%3D') for v in values]
            values = [' '.join(escaped)]
        if not any(len(v) for v in values):
            # Nothing but empty text: clean up and make parseable.
            values = 'True'
        normalised[name] = values
    return normalised


def feature_traversal(feature):
    """Yield ``feature`` followed by all of its descendants, depth first.

    Every feature is yielded before its own ``sub_features``, and siblings
    keep the order in which they appear in ``sub_features``.
    """
    yield feature
    # One iterator per level currently being walked; the innermost is last.
    open_levels = [iter(feature.sub_features)]
    while open_levels:
        for child in open_levels[-1]:
            yield child
            # Descend into this child before moving on to its siblings.
            open_levels.append(iter(child.sub_features))
            break
        else:
            # Innermost level exhausted: resume the one above it.
            open_levels.pop()


def main():
    """Tag and normalise every feature of the input GFF3, then write the output.

    All records are loaded into memory, each top-level feature and its
    descendants are passed through ``add_tags_to_feature`` and
    ``construct_new_qualifiers``, and the records are written to the output
    path given on the command line.

    Raises:
        AssertionError: if tagging a feature fails with ``KeyError``; the
            message carries that feature's qualifiers.
    """
    args = parse_args()
    # NOTE: this value is never reassigned below, so it is 1 for every call.
    noid_count = 1
    tx_counts = Counter()
    gene_counts = Counter()
    records = list(GFF.parse(args.input_gff3))
    for seqrecord in records:
        for parent_feature in seqrecord.features:
            for feature in feature_traversal(parent_feature):
                try:
                    add_tags_to_feature(feature, noid_count, tx_counts, gene_counts, parent_feature)
                except KeyError as err:
                    # Raised explicitly rather than via `assert False`, which
                    # is stripped under `python -O` and would swallow the error.
                    raise AssertionError(feature.qualifiers) from err
                if feature.type == 'CDS' and parent_feature.type == 'gene':
                    # convert gene into fake transcript; we lack a parent
                    # NOTE(review): parent_feature is the top-level feature, so
                    # this also applies to a CDS nested deeper below a gene -
                    # confirm that is intended.
                    parent_feature.qualifiers['transcript_biotype'] = ['protein_coding']
                    parent_feature.qualifiers['transcript_id'] = feature.qualifiers['transcript_id']
                    parent_feature.qualifiers['transcript_name'] = feature.qualifiers['transcript_name']
                feature.qualifiers = construct_new_qualifiers(feature)

    with open(args.output_gff3, 'w') as fh:
        GFF.write(records, fh)


if __name__ == '__main__':
    main()
